{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DPU example: MNIST Classifier\n",
    "\n",
    "This notebook shows how to deploy Convolutional Neural Network (CNN) model for \n",
    "hand-written digit recognition. The network was trained on the MNIST dataset,\n",
    "quantized using Vitis AI compiler tools, and deployed on the DPU.\n",
    "\n",
    "Compared to the other 3 notebooks delivered in this folder, this notebook\n",
    "shows how to deploy a **user-trained** DPU model on PYNQ image; i.e., \n",
    "the model used in this notebook does not come from the model zoo.\n",
    "Refer to [train your own DPU models](https://github.com/Xilinx/DPU-PYNQ/tree/master/host#train-your-own-dpu-models-from-scratch)\n",
    "to get more information.\n",
    "\n",
    "## 1. Prepare the overlay\n",
    "\n",
    "We will download the overlay onto the board. Then we will load the \n",
    "corresponding DPU model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "\n",
       "try {\n",
       "require(['notebook/js/codecell'], function(codecell) {\n",
       "  codecell.CodeCell.options_default.highlight_modes[\n",
       "      'magic_text/x-csrc'] = {'reg':[/^%%pybind11/]};\n",
       "  Jupyter.notebook.events.one('kernel_ready.Kernel', function(){\n",
       "      Jupyter.notebook.get_cells().map(function(cell){\n",
       "          if (cell.cell_type == 'code'){ cell.auto_highlight(); } }) ;\n",
       "  });\n",
       "});\n",
       "} catch (e) {};\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "try {\n",
       "require(['notebook/js/codecell'], function(codecell) {\n",
       "  codecell.CodeCell.options_default.highlight_modes[\n",
       "      'magic_text/x-csrc'] = {'reg':[/^%%microblaze/]};\n",
       "  Jupyter.notebook.events.one('kernel_ready.Kernel', function(){\n",
       "      Jupyter.notebook.get_cells().map(function(cell){\n",
       "          if (cell.cell_type == 'code'){ cell.auto_highlight(); } }) ;\n",
       "  });\n",
       "});\n",
       "} catch (e) {};\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "try {\n",
       "require(['notebook/js/codecell'], function(codecell) {\n",
       "  codecell.CodeCell.options_default.highlight_modes[\n",
       "      'magic_text/x-csrc'] = {'reg':[/^%%swig/]};\n",
       "  Jupyter.notebook.events.one('kernel_ready.Kernel', function(){\n",
       "      Jupyter.notebook.get_cells().map(function(cell){\n",
       "          if (cell.cell_type == 'code'){ cell.auto_highlight(); } }) ;\n",
       "  });\n",
       "});\n",
       "} catch (e) {};\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pynq_dpu import DpuOverlay\n",
    "overlay = DpuOverlay(\"dpu.bit\")\n",
    "overlay.load_model(\"dpu_mnist_classifier_0.elf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's import some libraries as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import time\n",
    "import numpy as np\n",
    "import mnist\n",
    "from dnndk import n2cube\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load test data\n",
    "\n",
    "The `mnist` package enables the following data for users:\n",
    "\n",
    "* `test_images()`: returns test images stored as a numpy array. \n",
    "Each image is a grayscale 28x28 pixels, representing a digit between 0 and 9.\n",
    "* `test_labels()`: returns a list of the true labels stored as numpy array.\n",
    "\n",
    "\n",
    "There are 2 pre-processing steps we need to do to the test images \n",
    "before we can use it:\n",
    "\n",
    "1. The raw numpy array delivered by `mnist` has a data type of \n",
    "uint8 (data ranges from 0 to 255); we need to normalize the elements to \n",
    "floating-point numbers ranging from 0 to 1.\n",
    "2. The DNNDK API will expect each input sample to have 3 dimensions; \n",
    "so we need to expand the original numpy array."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of test images: 10000\n",
      "  Dimension of each picture: 28x28\n"
     ]
    }
   ],
   "source": [
    "raw_data = mnist.test_images()\n",
    "normalized_data = np.asarray(raw_data/255, dtype=np.float32)\n",
    "test_data = np.expand_dims(normalized_data, axis=3)\n",
    "test_label = mnist.test_labels()\n",
    "\n",
    "print(\"Total number of test images: {}\".format(test_data.shape[0]))\n",
    "print(\"  Dimension of each picture: {}x{}\".format(test_data.shape[1],\n",
    "                                                  test_data.shape[2]))"
   ]
  },
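   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The two pre-processing steps above can be sanity-checked without the board.\n",
     "The following is a minimal sketch on a small synthetic array; the name `fake_raw`\n",
     "and the batch size of 4 are illustrative assumptions, not part of the `mnist` package.\n",
     "\n",
     "```python\n",
     "import numpy as np\n",
     "\n",
     "# Synthetic stand-in for mnist.test_images(): 4 uint8 images of 28x28 pixels.\n",
     "fake_raw = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)\n",
     "\n",
     "# Step 1: scale uint8 values in [0, 255] to float32 values in [0, 1].\n",
     "fake_normalized = np.asarray(fake_raw / 255, dtype=np.float32)\n",
     "\n",
     "# Step 2: append a channel axis so each sample is height x width x channel.\n",
     "fake_data = np.expand_dims(fake_normalized, axis=3)\n",
     "\n",
     "assert fake_data.dtype == np.float32\n",
     "assert fake_data.min() >= 0.0 and fake_data.max() <= 1.0\n",
     "assert fake_data.shape == (4, 28, 28, 1)\n",
     "assert fake_data[0].shape == (28, 28, 1)\n",
     "```\n",
     "\n",
     "Each sample then holds 28 x 28 x 1 = 784 values in height-width-channel (HWC) order."
    ]
   },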
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAACPRJREFUeJzt3W+IlWkZx/Hfte7IurDoOoVMmQsqBgoaMrJhUEuJLqVjkqXgUhT1NsLxD4iLEYkRQRGBLwILx14M08ikaUkvUoR1XIzUEnwzoG4x4k625mSSf65ezCwMNuce95zjmZnz+35AcPc6z/PcB/brPZ5nzzmRmQLg57nJXgCAyUH8gCniB0wRP2CK+AFTxA+YIn5jEXE6Ir7Z6GMxNRB/E4iIaxGxZrLXUUlEfC0i/hQR/4qIv0XEDyPi+clelzviRyO8KOk7kj4k6VVJn5O0Y1JXBOJvZhHxckT8NiLejYh/jv5+/hMPWxQRb0fEnYj4TUTMHXP8JyPirYh4LyIuRcRr1awjMw9m5tnM/G9m/l3SryR9qvpnhnog/ub2nKRfSHpF0gJJ/5H0syce81VJ35D0EUkPJf1UkiLio5JOSPq+pLka2al7I+LDT14kIhaM/gGx4CnX9WlJVz7ws0FdEX8Ty8x/ZGZvZt7LzLuS9kv6zBMP68rMv2bmvyW9KekrETFD0huSTmbmycx8nJl/kHRB0ufHuc6NzJyTmTcmWlNEfF1Su6Qf1fj0UCNedGliEfGipB9Lel3Sy6P/+qWImJGZj0b/+Z0xh1yX1KKRv5u/IunLEbFhzLxF0h9rWM8XJf1A0prMHKr2PKgP4m9unZI+LunVzLwZEZ+Q9GdJMeYxHxvz+wWSHkga0sgfCl2Z+a16LCQiXpf0c0lfyMy/1OOcqA0/9jePloh4Ycyv5yW9pJG/5783+kLevnGOeyMilo7+lPA9Sb8e/angiKQNEbEuImaMnvO1cV4wnFBEfFYjL/J9KTPfrvoZoq6Iv3mc1Ejo7//6rqSfSJqlkZ28X9LvxzmuS9IvJd2U9IKkb0tSZr4jaaOkPZLe1chPAjs1zn8zoy/4DRde8HtT0mxJJ0cfNxwRv6vqWaJugg/zADyx8wOmiB8wRfyAKeIHTDX0Pn9E8Ooi8IxlZkz8KHZ+wBbxA6aIHzBF/IAp4gdMET9givgBU8QPmCJ+wBTxA6aIHzBF/IAp4gdMET9givgBU8QPmCJ+wBTxA6aIHzBF/IAp4gdMET9giq/ongZ27NhRnM+aNavibPny5cVjN2/eXNWa3nfw4MHi/Ny5cxVnXV1dNV0btWHnB0wRP2CK+AFTxA+YIn7AFPEDpogfMBWZjfvWbL6ie3zd3d3Fea334ifTwMBAxdmaNWuKx964caPey7HAV3QDKCJ+wBTxA6aIHzBF/IAp4gdMET9givfzN8Bk3se/evVqcX7q1KnifOHChcX5hg0bivNFixZVnG3btq147IEDB4pz1IadHzBF/IAp4gdMET9givgBU8QPmCJ+wBT3+eugvb29ON+0aVNN579y5Upx3tHRUXE2NDRUPHZ4eLg4nzlzZnHe399fnK9YsaLirLW1tXgsni12fsAU8QOmiB8wRfyAKeIHTBE/YIpbfXXQ1tZWnEeUP0l5olt569atK84HBweL81p0dnYW50uXLq363CdOnKj6WNSOnR8wRfyAKeIHTBE/YIr4AVPED5gifsAU9/nr4Pjx48X54sWLi/O7d+8W57dv3/7Aa6qXrVu3FuctLS0NWgnqjZ0fMEX8gCniB0wRP2CK+AFTxA+YIn7AFPf5G+D69euTvYSKdu7cWZwvWbKkpvOfP3++qhmePXZ+wBTxA6aIHzBF/IAp4gdMET9givgBU5GZjbtYROMuBknS+vXri/Oenp7ifKKv6L5161ZxXvo8gDNnzhSPRXUys/xFEaPY+QFTxA+YIn7AFPEDpogf
MEX8gCniB0zxfv4m197eXpxPdB9/It3d3cU59/KnLnZ+wBTxA6aIHzBF/IAp4gdMET9gilt9TaCvr6/ibO3atTWd+/Dhw8X53r17azo/Jg87P2CK+AFTxA+YIn7AFPEDpogfMEX8gCk+unsaaGtrK84vXbpUcdba2lo8dmhoqDhfvXp1cT4wMFCco/H46G4ARcQPmCJ+wBTxA6aIHzBF/IAp4gdM8X7+aaC3t7c4n+hefsmRI0eKc+7jNy92fsAU8QOmiB8wRfyAKeIHTBE/YIr4AVPc558COjo6ivOVK1dWfe7Tp08X5/v27av63Jje2PkBU8QPmCJ+wBTxA6aIHzBF/IAp4gdMcZ+/ASZ6v/2ePXuK85aWlqqvffHixeJ8eHi46nNjemPnB0wRP2CK+AFTxA+YIn7AFPEDprjV1wCdnZ3F+apVq2o6f19fX8UZb9lFJez8gCniB0wRP2CK+AFTxA+YIn7AFPEDpiIzG3exiMZdbAq5f/9+cV7LW3Ylaf78+RVng4ODNZ0b009mxtM8jp0fMEX8gCniB0wRP2CK+AFTxA+YIn7AFO/nbwJz586tOHvw4EEDV/L/7ty5U3E20dom+v8fZs+eXdWaJGnOnDnF+fbt26s+99N49OhRxdnu3buLx967d68ua2DnB0wRP2CK+AFTxA+YIn7AFPEDpogfMMV9/iZw+fLlyV5CRT09PRVnE33WwLx584rzLVu2VLWmqe7mzZvF+f79++tyHXZ+wBTxA6aIHzBF/IAp4gdMET9gio/uboCjR48W5xs3bmzQSrw8fPiw4uzx48c1nfvYsWPF+YULF6o+99mzZ4vz/v7+4pyP7gZQRPyAKeIHTBE/YIr4AVPED5gifsAU9/mngF27dhXntX6Fd8myZcuK82f5ttlDhw4V59euXavp/L29vRVnV69erencUxn3+QEUET9givgBU8QPmCJ+wBTxA6aIHzDFfX6gyXCfH0AR8QOmiB8wRfyAKeIHTBE/YIr4AVPED5gifsAU8QOmiB8wRfyAKeIHTBE/YIr4AVPED5gifsAU8QOmiB8wRfyAKeIHTBE/YIr4AVPED5gifsAU8QOmiB8wRfyAKeIHTBE/YIr4AVPED5gifsAU8QOmiB8wRfyAKeIHTBE/YCoyc7LXAGASsPMDpogfMEX8gCniB0wRP2CK+AFTxA+YIn7AFPEDpogfMEX8gCniB0wRP2CK+AFTxA+YIn7AFPEDpogfMEX8gCniB0wRP2CK+AFTxA+Y+h++G6O2KV3iBgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f5e206a58>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(test_data[1,:,:,0], 'gray')\n",
    "plt.title('Label: {}'.format(test_label[1]))\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Set up DPU\n",
    "\n",
    "We can set the DPU kernel name, and input/output nodes.\n",
    "These pieces of information can be found from the Vitis AI compiler output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "KERNEL_NAME = \"mnist_classifier_0\"\n",
    "KERNEL_CONV_INPUT = \"conv2d_1_convolution\"\n",
    "KERNEL_FC_OUTPUT = \"output_logits_MatMul\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use `n2cube` to attach to the kernel driver. Load the kernel and create a task for it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "n2cube.dpuOpen()\n",
    "kernel = n2cube.dpuLoadKernel(KERNEL_NAME)\n",
    "task = n2cube.dpuCreateTask(kernel, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get input and output dimensions for the task, we will use these for later computations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_len = n2cube.dpuGetInputTensorSize(task, KERNEL_CONV_INPUT)\n",
    "size = n2cube.dpuGetOutputTensorSize(task, KERNEL_FC_OUTPUT)\n",
    "channel = n2cube.dpuGetOutputTensorChannel(task, KERNEL_FC_OUTPUT)\n",
    "conf = n2cube.dpuGetOutputTensorAddress(task, KERNEL_FC_OUTPUT)\n",
    "outputScale = n2cube.dpuGetOutputTensorScale(task, KERNEL_FC_OUTPUT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Run DPU to make predictions\n",
    "\n",
    "We can now classify a couple of digit pictures. For each picture, \n",
    "the classification result (shown as 'Prediction') is displayed on top of \n",
    "the picture. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA18AAABfCAYAAAAXrn30AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAHcZJREFUeJzt3XmUVNUV7/HvwVZRQQQUI2EQJU4gjlFEiSgsRZAxGoNGMUaQBIWoCILGARXeAiQaojgkiigiIKIQRAkJDqDGFyJgEMJDQ4MyCKLIIJJu7vvj1r5V1RPVNdyu4fdZi9VNVfWt07tv3apz9jn7OM/zEBERERERkcyqVdMNEBERERERKQTqfImIiIiIiIRAnS8REREREZEQqPMlIiIiIiISAnW+REREREREQqDOl4iIiIiISAiysvPlnDvWOec554oi/5/nnOubxHGaOed2OucOSH8rc4timhmKa/oppumnmGaG4pp+imn6KaaZobimX8HE1PO8pP4Ba4HvgJ3AZuBZoE6yxytz7GMBDyhKok2d0tGGaj5vs0gcYv95wO2KadLtbQRMBTYA24HFwLlJHktxjX/uB4CPgRLgPsU0bW1eCOwGViXTDsW00jZcGGn7g0n+vOIa/9x6/Wcutkmfq4ppueduB3wI7ACWAxckeRzFtfJ47ATmK6Ypt7na52qqma9unufVAc4EfgzcXfYBzpeVGbZ08Txvned5dewfcCqwD5iZxOEUU18d4P8CZwENgOeAuc65OkkeT3GNWgMMBeameBzFNGoq8BHQELgLeNk5d1QSx1FMYzjnDgQeBf6R4qEU1yi9/jMgTeeqYgo45xoAs4GxwBHAGGCOc65+kodUXON1i/nMekkqx6DAY5rsuZqWoHie9wUwD2gdacxbzrmHnHOL8UeCj3PO1XPO/dk5t9E594Vz7kFLBzrnDnDOjXPObXXOfQZ0LfPLveWcuzHm//2ccyudczucc5845850zj2Pn4GaE0k1Dq0gfdnYOTfbObfNObfGOdcv5pj3OeemO+cmR467wjl3dpIhuQ54x/O8tUn+fMHH1PO8zzzPG+953kbP80o9z3sKOAg4MdmYRo5b0HGNxOA5z/Pm4Y/SpKzQY+qcOwH/Dehez/O+8zxvJn5m4aeKacrX1NuB+fjZxJQprnr958K5qpjSDtjsed6MyPv/C8AWoHeSIQUU10xQTJM8V1NIs60lkuIDmgIrgAci/38LWAe0AoqAA4FXgSeBw/CnlH0I3BR5/AD8C1ZT/CzHQmLSjpHj3Rj5/krgC/yetgNaAs0rSjtSJn0JvA08DtQGTo8EqGPkvvuAPUAX4ABgNPBBzLEeBx5PMDafAtcrpmmN6emRY9VTXNMTV+AFUpt2pJj69/UCVpa57Y/ABMU0+fMUaA6sxs+CTyK1aYeKq17/WX2uKqZxsegGfFLmtv8H/F5xTflcXYs/VXAL/mDBaYpp+OdqtS+6ZYK/E/gGKI407pCYYI2MeezRwPd2f+S2PsDCyPd/BwbE3HdJFcF/Exi8vxOibPAjf9hSoG7M/aOBSTHBXxBz3ynAd0nEpX0kLtWe/6qYVhqXw/EzCcN1rqY1rql++FJM/cdeS8yFOnLbQ3ZsxTS58xR4Dbgq8v0kUut8Ka7l26DXfxadq4pp3PM2jMShD/6H9774SzmeVFxTPlfPBw4BDgWGA5uAIxTTcM/VIlLT0/O8BZXctz7m++aRRm10ztlttWIe07jM44ureM6m+Jml6moMbPM8L3aqRTEQm1rcFPP9bqC2c67I87ySajxPX2Cm53k7k2gjKKZxnHOHAHPwP9yOTqKNRnFNP8XUtxN/gCDW4SQ3rUsxBZxz3fDfKKcl0a6KKK7pp5iS9nNVMQU8z/vKOdcDGAc8hv+hewHweRLtBMU14Hne4pj/jnZ+JcH2+J+zqkMxJflzNdXOV5Vtivl+PX7P98hKfpGN
+EE1zao47nrg+ASes6wNQAPnXN2YP0Az/BRmWkQ6ClfiT0PKhIKKqXPuYPx09RfATek4ZiUKKq4hKaSYrsCf1x577NOAF9Nw7FiFFNOOwNnOOXtDrAeUOudO9TyvRxqOH6uQ4hqWQoppWOdqIcUUz/Pexp9eRmTdzqfAw+k4dtmnivk+7+NaSVvcfh9V/WOavI9pMudqKFVIPM/biD+39GHn3OHOuVrOueOdcxdGHjIdGOSca+L8CiF3VnG4PwFDnHNnOV9L51zzyH2bgeMqacN64D38nn5t51wb4FfAlDT8iqYXfvpxYRqPWaF8j6nzK0e9jF/O9DrP8/alesxE5HtcwY+tc642/uu/KPIcGdsLI99j6nneamApcG/k2L2ANiRX7TTR58zrmAK/A07An5t/On41qaeBX6bh2JUqgLjq9Z8H52oBxBTn3BmRc/Vw/KzC557nvZmOY1cm3+Pq/L2vznfOHRQ59h3Akfhb+WREvscUkjtXwywBeR1+tbpPgK/xP1gfE7nvafxU3TLgX8ArlR3E87wZ+OspXsSf1vMq/iI98Odw3u2c+8Y5N6SCH++DPw90AzALvzrZXxNpvHPuCefcE/t5WF9gsheZCBqCfI5pO+By/Pm/3zi/gs1O51z7RI6donyOq/0O30We467I99cmcuwU5HtMf44/heFr4P8AV3ietyWRY6cgb2Pqed4Oz/M22T/8c3SX53nbEjl2ivI2rjG/g17/uX+u5m1MI4YCW/GzHceQuRlFZeVzXOsCEyO/1xdAZ+Ayz/O+SuTYKcjnmEIS56oLr58gIiIiIiJSuPJ68zMREREREZFsoc6XiIiIiIhICNT5EhERERERCYE6XyIiIiIiIiFQ50tERERERCQEmdxkuRznXM6XVvQ8L92b0aVEMU0/xTQzFNf0U0zTTzHNDMU1/RTT9FNMM0NxjafMl4iIiIiISAjU+RIREREREQmBOl8iIiIiIiIhCHXNl2SnIUOGAHDIIYcA0KZNGwCuuOKK4DETJ04E4P333wfg+eefD7OJIiIiIiI5T5kvERERERGREDjPC68AiaqdpF+qMZ02bVpchmt/Pv30UwA6derEunXrUnnqQL7FtLpOOOEEAFatWsXgwYMBmDBhQkrHzLaYQmbiethhhwEwduxYbrrpJgCWLFkCwJVXXglAcXFx2p4v2+Kqa2r6KaaZobimn2KafjUV0/r169OsWbMK7ysuLubWW28F4N///jcAq1evBmDZsmXlHp9tMQWdq2Up8yUiIiIiIhICrfkqUNOmTQOoMOu1atUqAN58802OO+44ALp16wbA8ccfD8A111zD6NGjw2hq3jvjjDMA2LdvH59//nkNtya3HHPMMQD069ePffv2AXDWWWcBcPnllwPw2GOP1UzjcsSZZ54JwCuvvALAscceW62fv+SSSwBYuXIlAOvXr09f4wqAXVtnz57NzTffDMATTzwBQGlpaY21K9s0atSI6dOnA/Dee+8B8NRTTwGwdu3apI5Zr149fvKTnwDwxhtvAPC///0vxZaK7F/Xrl0B6N69OwAdOnSgZcuWFT529erVNG/eHICDDz447r4DDjggg62UTFHnq8CcffbZAPTq1Su4bcWKFUD0IrB161YAdu7cyUEHHQTABx98AMBpp50GQMOGDcNpcAE4/fTTAdi1axezZs2q4dbkhqOOOgqA5557roZbkvsuvfRSoPybeqKs83DDDTcA8POf/zw9Dctzdg19/PHHg9v++Mc/AvDMM88A8N1334XfsCxTv359wH+fqlevHgCbN28GUut0gT9F2a4lNmizZs2aVJqbsw4//HCAYFC1devWdOrUCVCHNFU2aD1w4ED69esHRAucObf/mWy2NEHyh6YdioiIiIiIhCArM182Fc5GCDZs2ADAnj17mDJlCgCbNm0CCneUKlk2TctGW1asWBGMfG/cuLHc42+//XYATjnllLjb586dm8lmFoTWrVsDBFONVL5//wYN
GgRAz549ATjnnHMqfaxNJ6pVq1awKPmdd97JcAtzS1FREV26dEnpGFbg5LbbbgP8Iii7du1KuW35zs7PJk2aBLdNnToV8N/rCt2RRx4JRKfIN2jQIMgS3nLLLSkd++677wagRYsWQaGeQv0scc011wDw0EMPAdC0adPgPsuGffXVV+E3LI/Ya9wKaiXKloDY7CSpWMuWLYPrhc3q6tChA+Av57Bp3IsXLway47WuzJeIiIiIiEgIsjLzNWbMGKDihd82SrVjxw4g+REBK2wwZswY/vnPfyZ1jFw0Z84cgGBh544dO9i2bVulj7f1GwceeGDmG1dgTjrpJCBaLt1GeKVyv//97wGC4hpV6d27d/DVys1fddVVQDRbU+guuugizjvvPCB63a0uW5Nj2fFDDz1Uma8q2Nq6u+66q9x9lv0OcwuYbGWFYGwEG2DkyJEpHbNVq1ZAdEbHrFmzCvq626RJEx555BEgugYx9tyzLU9sdkZVnxUKnWVeBg8eHGRYrIjL999/D8D27duDa6O978+fPx/wS8j/4x//AOCjjz4Coms+dT2NV3bWUO/evYP4V+Tcc88FoKSkBID//Oc/ACxatCjIRu7duzdj7a2IMl8iIiIiIiIhyMrMl631atOmDRAtYXzyySeXGw1r27Yt4Jc3jp2rHKukpIQtW7YA0TVPZt26dQWV+TKJbDx7xx13lKuyYyMz9lWSN3ToUCD6tyjE87A6Xn/9dWrV2v94ka1P2LlzJwDNmzenRYsWAHz44YeAyvPayOHUqVODjdNHjRqV1LF69OiRtnYVglNPPRWIVtczJSUlzJs3ryaalFUaNWoEwE9/+tO423/1q18F7+PVZRmvBQsWxN0+a9asYBZNIRoyZAgNGjSo9H6bKdC5c2cgui5swoQJoWcKslXZDNZpp50WV00aotWizzzzzKBCp22obLOwEpnNUcjatGnDwIEDgeh5aWsSAb744gsA3n33XQD++9//Av7nLJvpYmvE7Zzv0qVLsB7c1oWFJSs7X3/729/ivhpL4UJ0qouV6V6yZAk//vGPKzzenj17gt3ArSNnwbcPHhJl+yONHDkyKDX/5ZdfAjB8+HAAdu/eXTONywM2ndbK/tu5qakFFbvwwgsBOPHEE4M3qIreqOziaW+C27dvB+Diiy8uN8Xr17/+NQATJ07MTKOznBUcOOyww4IPVtZZTZRdQ+3vow8PiSnbqTB23ha6hx9+GIBf/OIXQHSK8IwZM5I+Zvv27QE4+uijAZg0aRIAL7zwQtLHzGW2Z9Qvf/nL4Lbly5cD0TL+VmYeoqX5hwwZAsCUKVOComeFyj4bvfjii0B0G55Ro0aV6+Sb2K0R1q1bl9kG5oknn3wS8AtplJ1aaH2Ejz/+mBEjRgDlixW1a9cueL+3LTys37B58+ZgH9CZM2cCJD3AU12adigiIiIiIhKCrMx8JeLrr78GYOHChcFtZTNlsWy00TJmH3/8MaAiBxWxjIyN7EA0Tm+//XaNtCmfWKbAhDXSkmssQ/jSSy8BVLig1qZszpw5k/vvvx8on5UtLi6mf//+QHRzZisuUbt27WBj20LYSNS28bDy8mvWrEl6uqtlEy3j9dZbbwHwzTffpNjK/GYl5o1N36qoAEchsoIPdl7ZVjPVneZmm9iOGDGC3/zmN3HHtg3BC5WN/NetWzeYpmXvS7Vr1wagT58+QTbBNgn+wQ9+AMBrr73GZZddBhRmEY46deoEs4BsptDWrVsBGDdunGYGpcDOP1uWceONNwL+9kj2WclmrIwdOxaoetZQw4YNg2UG9913HxCdRWcZ4JqgzJeIiIiIiEgIcjbzVR2NGjUKNme0BftWsrYQR20q8+qrrwJwySWXBLdNnjwZiK4RkdTZgnuTbInvfFdU5F+eKsp4WQbWtkKwUceKFBcXM3r0aADGjx8P+OXQwY/97NmzgcJY/3nllVcC0d/frovVdeyxxwabs5aWlgLw4IMPAoWRQUxWu3bt
aNeuXdxtNmq7dOnSmmhS1uvatSvgr4mzrGpVazUtg1O2KBfAyy+/nKFW5hbb7sDzvGD7DmNrZp599tngenHcccfFPWb37t0FXXCjZ8+e3HnnnUB07ZatK7S1xpIce93ecccdgJ/xAr+ghs1gs8JZFbEslxXgmzx5Mq+//joQnflmnHPB9h5hz9hQ5ktERERERCQEBZH5GjhwYLDWw9aK2SZrEi2/byOyNiq2devWYDS7upXQpGJt27YNKkzZRop//etfa7JJOcXWJ9majaoyXrEsu2XZmsoqo+Yrq1YWmwWA5Ks99u/fP8hIWgXZ2PW3UrGKzrtCrbhZmUcffRTwNwAHaNy4MeCvlbNR8O7du1f68/aY2M2CP/vsM4BgDVOh69OnT/C9ZRZt5kssW/9d1gcffFDQnwlis9f2Pm4l4yU1lrmyGRWmpKQk2CzZ1i6fdNJJwf22IfXJJ58c93Xr1q1BldOyNm/eXGMzNvK683X++ecDBOlh8NPF4O8mLj4rsWk73JsXXnihIKZihalTp05BiW5b9Fm2NKrEi93byy6+1WUfyOxYsce0RbjXXnttki3Mfjag8sMf/hDw9/dKhS3AB11LqyP2w2wiU+gKkZWWt30+rThE586dg6lItvD+ueeeK/fzNo3I9u8BeO+994DCmFqcCHv9d+/ePRgQsA+yNi2+V69ewTQtO1ft//369Qvi/Mknn4TX8CxhH/4hugfavffeC/jFSDSFOHl///vfgehgnm150KxZM/7whz8A8QMr4HfUKtu7M7bjZUV8Zs2aBcCgQYPYuHFjGlufOE07FBERERERCYEr24PM6JM5F96TEd2Nffjw4UEZeiuxnGyK0fM8l57WpUeqMe3evTvTp08H4MADDwSiJaN79OgRytSCfItpVWbMmBEsGrWvNgqTTtkWU6h+XMeNGwfA4MGDg9vsHK2uW265BYgW3LDM1759+4IR30RGxbMtronG1MpuW1lpi+NFF11UraJDjRo1AogbLRw0aBBAsFlldeVqTKvjggsuAPxCMXbu2TYJtqVCOmVbTCG8938rDrFmzRrAL2Ry6aWXAqlv65FtcU02pjb7Ys2aNcGU5Iqma9pmwQMHDgTgL3/5CwA/+tGPePrppwEYMGBAMk0I5GJMPc+rdFP5ffv28cQTTwD+9Ezwszbgx3vFihVxj2/VqhUA77//ftqmLmZbTCH5c/WII44A/BlsNpvtq6++AqLFTg4++OBgk+tzzjmn0mPZ38WmH1e3yEY646rMl4iIiIiISAjycs2XjfLaXNy9e/cG83FVBtln67tGjBhRLptg85ULeUFtutnmlO3btw+KvWQi45VPunXrltLPW5GdU045pdKF9lu2bCmIa4ItRrbsnmVd586dG2QDK9K6dWsgmk2wLE3s6HhlI8ASZdfb2LWGKrSTGffccw8QPUeHDRumjezLsGz3z372s6D8vmXAzIQJExg2bBgQXZf8yiuvAH4WwrKJtv6zkNbTjRs3jttuu63C+2rVqhVs6m1fE7Fly5Zg1pFtoSLR7FRs7YaK2LZIZTNfO3bsCP5WkyZNAsoX86gJynyJiIiIiIiEIC8zX1YR6YwzzgD8qnJW7Uh8t99+OxBf+thKzVqWUNLn+uuvB/w1M/PmzavZxhSIu+66C4iuV4i1du1aAPr27RvMGy8E9tq29R1du3atsvKhlfK3LEJFG17baKJULrY6mo3kPvnkkzXVnLxkGwJfd911gD/iDdH1IVLeggULgnPz6quvBqLn5z333FOuEu8DDzwA+GW8rdy/ZRr79u0bSpuzwZ133sm0adMAePHFFwEoKvI/Tjdt2jQuw52oo446Kvhb3H333UB043qp2tChQyvNFg4YMCDl6r6ZkFedL9uv4ne/+x0A3377LQAjR46ssTZlq4pS5jfffDOg6YaZ0Lx58+B722tOMsN2sz/xxBMrfYyVR160aFEobcoWq1atAvzpRuCX8W7ZsmWlj7cpScZKe9t+
aRCd0ijlNWnSBIh+sIXofkC2Z52kx2WXXRb3fysO8a9//asmmpMzrKiGfa2KvdanTZsWdL5sPzYr4lGdAj65qrS0NHj9nnDCCXH3dezYMVjKYduYJLqvpA2KnXXWWWlqaX678cYbAb+zap1fY4VNbKpsttG0QxERERERkRDkTearYcOGwQZsttmajYBbuU+pmo1cVVWAYPv27cFjbHSn7ELdI444otLFqKWlpcEi3t27d6fc5lxx+eWXB9/PmTOnBluSO8pujAzlR7efeuopABo3bhzcFltGvjKpFvPIF0uXLq3WhqCfffZZudusKIc2Wy6vXbt2QPw5bNO7Jb3s2rBr1y4AHn744ZpsTl6bPn16kPm66qqrgOjMmUKfaWTbGkF0g3DLfJWUlPDss88CBKX6f/vb3wLx2XGpmhXVsNd4nTp1gvts5pZtgfD999+H3LrEKPMlIiIiIiISgpzPfFmW64033qBFixZAtOSprf2SxCxfvny/j5kxYwbgb7J69NFHA9GRr0Rt2rQJiG6Cnc9sc1UrNS+JmzhxIgBjxowJbrN1HGWzWhVluSq6zTZZlORYNtK+gjJeVbES82br1q08+uijNdSa/DVgwIDg/ejLL78EtNYrk/bt2xdcl3v06AFEi/m89NJLrF69usbalk3mz58PRD/rFBUV0a9fP4BgrW2HDh3K/Vy6NlvOVzZzpW7dusFtlvG2jOzixYvDb1g1KPMlIiIiIiISgpzPfNkGf7HVYWy9USFt+lddth7ORq0SZeV8K1JSUgLEZxxmz54NxFf2evfdd6v1nLmsV69eQDRD+9FHH/HOO+/UZJNyhlUpsq0jbNPkRNnGqitXrqR///6An7GV5FnJ+dhNlqVythGtWbduXbBuVtJnwIABwTk5d+7cuPvq1q1L/fr1AQpqW4lMs7WiVmp+7NixAIwaNYprr70WUCXUlStXAv4aOYhWmYVolUhTWloanLv721C4UFmma+jQoeXumzJlCkCwUXW2y9nOl5XutrQuRD+k2dQkqVzv3r0B/yS2whlltWrVqtIphc8880ywV5KZOXMmEC1nXcgOPfRQALp06RJ3+8svv5wVu6vnguLiYoBg/46ePXsyePDghH/epno89thj6W9cgapdu3bwfaF/sKqKXVNtcNDs2bOnyoJGkjq7vtp2CLfeemtQdrqQ9qIKy+TJkwG46aabAP+zhRXdSGQpQz6za6QV1ahTpw5nn3024O/5CdE9J59//vmgNL2UV6dOnWCLmLKfWZcvXx7EOFdo2qGIiIiIiEgIXJjTR5xzaXsyG9UePnx4cJuVn8zk5pWe57n9Pyo86YxpTcnHmNrIzNtvvw1EF4FfffXVoZTYz7aYQnri2rlzZ4BgGqEtvJ09e3ZQdt6KQdgoWTqnGmVbXMN+/VuxnKKiIh544AGAlAtI5GNMbZrxn/70JwCuv/56wM8ShJF9ybaYQmbP1aVLl3Lqqafa8wDRqbF//vOfg3N1/fr1KT1PtsU1m97/mzVrBviZnKlTpwLxm7FXptBialMy27ZtC8D9998PRD8jpEO2xRRSj2v37t157bXXgPLT3jt27MjChQtTOXxC0hlXZb5ERERERERCkHOZLyvdbQUjYjdXU+YrNymm6ZdtMQXFNRPCjqltED5+/Pi0jTTmc0xt8+8HH3wQgCVLloSyBjHbYgqZPVcvuOCCYJ2RFTSyrSq+/vpr9u7dm5bnyba4ZuM1df78+Zx33nkAnHvuuUB0FkJFFNP0y7aYQupxXbZsWZDdNlbkZdiwYakcOmHKfImIiIiIiOSYnKt22L59eyA+4wV+WfmdO3fWRJNERAqCrbGTxGzYsAGAG264oYZbkt8WLVrExRdfXNPNEOCKK65g2bJlQHQj4aoyXyKJaNCgQbCe09bHPfLIIzXZpJTkXOerLHuRd+zYkW3bttVwa0REREQK07fffkuLFi1quhmSZ8aPH8/48eMBggI6ubxvp6YdioiIiIiIhCDn
Cm7UtGxbyKiYpp9imhmKa/oppumnmGaG4pp+imn6KaaZobjGU+ZLREREREQkBKFmvkRERERERAqVMl8iIiIiIiIhUOdLREREREQkBOp8iYiIiIiIhECdLxERERERkRCo8yUiIiIiIhICdb5ERERERERCoM6XiIiIiIhICNT5EhERERERCYE6XyIiIiIiIiFQ50tERERERCQE6nyJiIiIiIiEQJ0vERERERGREKjzJSIiIiIiEgJ1vkREREREREKgzpeIiIiIiEgI1PkSEREREREJgTpfIiIiIiIiIVDnS0REREREJATqfImIiIiIiIRAnS8REREREZEQqPMlIiIiIiISAnW+REREREREQqDOl4iIiIiISAj+P+3bvtMDVjh/AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f5e1af048>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_pics  = 10\n",
    "fix, ax = plt.subplots(1, num_pics, figsize=(12,12))\n",
    "plt.tight_layout()\n",
    "for i in range(num_pics):\n",
    "    n2cube.dpuSetInputTensorInHWCFP32(task, KERNEL_CONV_INPUT, \n",
    "                                      test_data[i], input_len)\n",
    "    n2cube.dpuRunTask(task)\n",
    "    softmax = n2cube.dpuRunSoftmax(conf, channel, size//channel, outputScale)\n",
    "    prediction = softmax.argmax()\n",
    "\n",
    "    ax[i].set_title('Prediction: {}'.format(prediction))\n",
    "    ax[i].axis('off')\n",
    "    ax[i].imshow(test_data[i,:,:,0], 'gray')"
   ]
  },
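   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The `dpuRunSoftmax` call above receives the output tensor address (`conf`), the number of\n",
     "classes (`channel`), the batch size (`size//channel`), and `outputScale`. As a rough mental\n",
     "model, and assuming the raw DPU output holds int8 fixed-point logits that are multiplied by the\n",
     "scale before the softmax, the computation can be sketched in NumPy. The function\n",
     "`scaled_softmax` and the sample values are hypothetical; this is not the DNNDK implementation.\n",
     "\n",
     "```python\n",
     "import numpy as np\n",
     "\n",
     "def scaled_softmax(raw_logits, scale):\n",
     "    # Convert fixed-point values to floating point using the tensor scale.\n",
     "    logits = raw_logits.astype(np.float32) * scale\n",
     "    # Subtract the maximum for numerical stability; the result is unchanged.\n",
     "    exp = np.exp(logits - logits.max())\n",
     "    return exp / exp.sum()\n",
     "\n",
     "# Ten made-up int8 logits, one per digit class.\n",
     "raw = np.array([-12, 3, 40, 7, -5, 0, 1, 2, -30, 9], dtype=np.int8)\n",
     "probabilities = scaled_softmax(raw, scale=0.25)\n",
     "\n",
     "print(probabilities.argmax())  # 2, the index of the largest logit\n",
     "```\n",
     "\n",
     "For a positive scale the softmax preserves ordering, so `argmax` over the probabilities picks the\n",
     "same class as `argmax` over the raw logits; the softmax matters only when the probabilities themselves are needed."
    ]
   },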
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also evaluate on the entire test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classifying 10000 digit pictures ...\n",
      "Overall accuracy: 0.9858\n",
      "  Execution time: 3.3749s\n",
      "      Throughput: 2963.0354FPS\n"
     ]
    }
   ],
   "source": [
    "total = test_data.shape[0]\n",
    "predictions = np.empty_like(test_label)\n",
    "print(\"Classifying {} digit pictures ...\".format(total))\n",
    "\n",
    "start = time()\n",
    "for i in range(total):\n",
    "    n2cube.dpuSetInputTensorInHWCFP32(task, KERNEL_CONV_INPUT,\n",
    "                                      test_data[i], input_len)\n",
    "    n2cube.dpuRunTask(task)\n",
    "    softmax = n2cube.dpuRunSoftmax(conf, channel, size//channel, outputScale)\n",
    "    predictions[i] = softmax.argmax()\n",
    "stop = time()\n",
    "correct = np.sum(predictions==test_label)\n",
    "execution_time = stop-start\n",
    "print(\"Overall accuracy: {}\".format(correct/total))\n",
    "print(\"  Execution time: {:.4f}s\".format(execution_time))\n",
    "print(\"      Throughput: {:.4f}FPS\".format(total/execution_time))"
   ]
  },
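   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The overall accuracy does not show which digits are confused with which. A minimal sketch of a\n",
     "per-class breakdown follows, assuming integer labels and predictions shaped like the `test_label`\n",
     "and `predictions` arrays above; the helper `confusion_matrix` and the toy arrays are hypothetical.\n",
     "\n",
     "```python\n",
     "import numpy as np\n",
     "\n",
     "def confusion_matrix(labels, predicted, num_classes=10):\n",
     "    # Rows are true classes, columns are predicted classes.\n",
     "    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)\n",
     "    # np.add.at accumulates correctly when an index pair repeats.\n",
     "    np.add.at(matrix, (labels, predicted), 1)\n",
     "    return matrix\n",
     "\n",
     "# Toy data with 3 classes; 4 of the 6 predictions are correct.\n",
     "toy_labels = np.array([0, 1, 1, 2, 2, 2])\n",
     "toy_predictions = np.array([0, 1, 2, 2, 2, 1])\n",
     "\n",
     "cm = confusion_matrix(toy_labels, toy_predictions, num_classes=3)\n",
     "print(cm)\n",
     "# The diagonal counts correct predictions, so trace / total is the accuracy.\n",
     "print(np.trace(cm) / cm.sum())\n",
     "```\n",
     "\n",
     "With the real arrays, `confusion_matrix(test_label, predictions)` would give a 10x10 table whose\n",
     "off-diagonal entries count the misclassified digits."
    ]
   },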
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make sure to clean up when you are done."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n2cube.dpuDestroyKernel(kernel)\n",
    "n2cube.dpuDestroyTask(task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (C) 2020 Xilinx, Inc"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
